import clip
import torch
import numpy as np
from PIL import Image

from label_reward import center_crop


# Target device used by the helpers in this module when moving inputs they
# build themselves: "cuda" when torch reports CUDA as available, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_torch_clip_reward(clip_model, obs, pos_text, use_crop=False):
    """Score a single observation against text prompt(s) with a CLIP-style model.

    Args:
        clip_model: ``(model, preprocess)`` pair. ``preprocess`` is applied to
            a PIL image, and ``model(image, text)`` must return exactly two
            values; the second one is used as the per-text logits.
        obs: Image array, converted via ``np.array`` and ``Image.fromarray``
            (presumably H x W x C uint8 -- confirm with callers).
        pos_text: Prompt string or list of prompt strings handed to
            ``clip.tokenize``.
        use_crop: When true, ``center_crop`` is applied first with a target
            size of ``obs.shape[0] // 2`` in both dimensions.

    Returns:
        A float32 NumPy array: the mean over the first axis of the text logits
        when ``pos_text`` is a list, otherwise the first row of those logits.
    """
    model, preprocess = clip_model

    frame = obs
    if use_crop:
        # NOTE(review): the first dimension of obs sets both crop sizes.
        half = obs.shape[0] // 2
        frame = center_crop(obs[None, ...], (half, half))[0]

    pil_frame = Image.fromarray(np.array(frame))
    image_batch = preprocess(pil_frame).unsqueeze(0).to(device)
    tokens = clip.tokenize(pos_text).to(device)

    with torch.no_grad():
        _image_logits, text_logits = model(image_batch, tokens)

    # A list of prompts is averaged into one score; a single prompt uses row 0.
    reward = text_logits.mean(axis=0) if isinstance(pos_text, list) else text_logits[0]
    return reward.float().detach().cpu().numpy()


def get_torch_clip_adapter_reward(clip_model, obs, pos_text, use_crop=False):
    """Score a single observation against text prompt(s) via separate encoders.

    Unlike a direct ``model(image, text)`` call, this computes the logits by
    hand from ``model.encode_image`` and ``model.encode_text`` and scales them
    with the exponentiated ``logit_scale``.

    Args:
        clip_model: ``(model, preprocess)`` pair. ``model`` must provide
            ``encode_image`` and ``encode_text``, and expose ``logit_scale``
            either directly or on ``model.clip_model``.
        obs: Image array, converted via ``np.array`` and ``Image.fromarray``
            (presumably H x W x C uint8 -- confirm with callers).
        pos_text: Prompt string or list of prompt strings handed to
            ``clip.tokenize``.
        use_crop: When true, ``center_crop`` is applied first with a target
            size of ``obs.shape[0] // 2`` in both dimensions.

    Returns:
        A float32 NumPy array: the mean over the first axis of the transposed
        logits when ``pos_text`` is a list, otherwise their first row.
    """
    model, preprocess = clip_model

    frame = obs
    if use_crop:
        # NOTE(review): the first dimension of obs sets both crop sizes.
        half = obs.shape[0] // 2
        frame = center_crop(obs[None, ...], (half, half))[0]

    pil_frame = Image.fromarray(np.array(frame))
    image_batch = preprocess(pil_frame).unsqueeze(0).to(device)
    tokens = clip.tokenize(pos_text).to(device)

    with torch.no_grad():
        image_emb = model.encode_image(image_batch)
        text_emb = model.encode_text(tokens)
        # The scale lives on the model itself or on a wrapped ``clip_model``.
        if hasattr(model, 'logit_scale'):
            scale = model.logit_scale.exp()
        else:
            scale = model.clip_model.logit_scale.exp()
        # NOTE(review): embeddings are used as returned by the encoders; no
        # normalization happens here -- confirm the encoders handle it.
        similarity = image_emb @ text_emb.T
        text_logits = (scale * similarity).t()

    # A list of prompts is averaged into one score; a single prompt uses row 0.
    reward = text_logits.mean(axis=0) if isinstance(pos_text, list) else text_logits[0]
    return reward.float().detach().cpu().numpy()


def get_torch_ts2net_reward(
    clip_model,
    video,
    pos_seq_output,
    pos_input_mask
):
    """Score a video against a pre-encoded text sequence with a TS2Net-style model.

    The forward pass runs under ``torch.no_grad()``: the result is returned as
    a NumPy array, so no autograd graph is needed. This matches the other
    reward helpers in this module.

    Args:
        clip_model: ``(model, _)`` pair; only ``model`` is used. It must
            provide ``get_visual_output`` and ``get_final_similarity``.
        video: Array/tensor whose first dimension sets the length of the
            all-ones video mask (presumably the frame count -- confirm with
            callers). It is reshaped to ``(1, 1, 1, *video.shape)`` before
            being passed to the model.
        pos_seq_output: Pre-computed text output passed through to
            ``model.get_final_similarity``.
        pos_input_mask: Text mask passed through to
            ``model.get_final_similarity``.

    Returns:
        The squeezed similarity as a float32 NumPy array.
    """
    model, _ = clip_model

    # Every entry of the mask is 1, i.e. no position along dim 0 is masked out.
    video_mask = torch.ones(video.shape[0], device=device)
    video = video.reshape(1, 1, 1, *video.shape)

    with torch.no_grad():
        vid_output = model.get_visual_output(video, video_mask)
        similarity = model.get_final_similarity(
            pos_seq_output,
            vid_output,
            pos_input_mask,
            video_mask,
            loose_type=True,
        )

    pos_reward = similarity.squeeze().float().detach().cpu().numpy()
    return pos_reward


def get_torch_mugen_reward(
    clip_model,
    video,
    pos_seq_output
):
    """Score a video against pre-encoded text embedding(s) with a MUGEN-style model.

    The forward pass runs under ``torch.no_grad()``: the result is returned as
    a NumPy array, so no autograd graph is needed. This matches the other
    reward helpers in this module.

    Args:
        clip_model: ``(model, _)`` pair; only ``model`` is used. It must
            provide ``get_video_embedding`` (called with ``{"video": video}``)
            and a ``temperature`` tensor.
        video: Video input forwarded unchanged to ``model.get_video_embedding``.
        pos_seq_output: 2-D tensor of text embeddings, one per row.

    Returns:
        A float32 NumPy array of the video/text products scaled by
        ``exp(model.temperature)``. With a single text row the trailing axis
        is squeezed away; with several rows the scores are averaged over them.
    """
    model, _ = clip_model

    with torch.no_grad():
        vid_output = model.get_video_embedding({"video": video})
        reward = (vid_output @ pos_seq_output.T) * torch.exp(model.temperature)
        if pos_seq_output.shape[0] == 1:
            reward = reward.squeeze(-1)
        else:
            reward = reward.mean(axis=-1)

    return reward.float().detach().cpu().numpy()

